from dataclasses import dataclass
from typing import Optional

from megatron.core.packed_seq_params import PackedSeqParams
from torch import Tensor


@dataclass
class GpatchPackedSeqParams(PackedSeqParams):
    """Megatron ``PackedSeqParams`` extended with extra attention options.

    Inherits every field of ``PackedSeqParams`` unchanged and appends the
    three fields below, each of which has a default value.

    NOTE(review): the previous docstring described these as
    "MemoryEfficientAttention params". The code that reads the fields lives
    outside this file, so the per-field notes below are hedged and should be
    confirmed against that consumer.
    """

    # Zigzag-ordering flag, enabled by default. Presumably selects the
    # load-balanced zigzag split of packed sequences (e.g. for context
    # parallelism) -- TODO confirm with the consuming attention code.
    use_zigzag: bool = True
    # Optional ``slice`` applied on the key/value side; ``None`` by default.
    # Annotated Optional because ``None`` is a legal, default value.
    kv_slice: Optional[slice] = None
    # Optional stride over key heads; ``None`` by default.
    # NOTE(review): likely the number of KV heads handled per step -- verify.
    heads_k_stride: Optional[int] = None
